opentau 0.1.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- opentau/configs/default.py +16 -0
- opentau/configs/deployment.py +85 -0
- opentau/configs/train.py +5 -0
- opentau/datasets/factory.py +43 -10
- opentau/datasets/lerobot_dataset.py +19 -19
- opentau/datasets/video_utils.py +11 -6
- opentau/policies/pi05/configuration_pi05.py +9 -6
- opentau/policies/pi05/modeling_pi05.py +296 -30
- opentau/policies/pi05/paligemma_with_expert.py +20 -20
- opentau/scripts/grpc/__init__.py +19 -0
- opentau/scripts/grpc/client.py +601 -0
- opentau/scripts/grpc/robot_inference_pb2.py +61 -0
- opentau/scripts/grpc/robot_inference_pb2_grpc.py +210 -0
- opentau/scripts/grpc/server.py +313 -0
- opentau/scripts/launch.py +12 -4
- opentau/scripts/train.py +94 -17
- opentau/scripts/visualize_dataset.py +141 -38
- opentau/utils/transformers_patch.py +251 -20
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/METADATA +37 -17
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/RECORD +24 -21
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/WHEEL +1 -1
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/entry_points.txt +1 -0
- opentau/scripts/libero_simulation_parallel.py +0 -356
- opentau/scripts/libero_simulation_sequential.py +0 -122
- opentau/scripts/visualize_dataset_html.py +0 -507
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/licenses/LICENSE +0 -0
- {opentau-0.1.1.dist-info → opentau-0.2.0.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,210 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
|
|
16
|
+
"""Client and server classes corresponding to protobuf-defined services."""
|
|
17
|
+
|
|
18
|
+
import grpc
|
|
19
|
+
from opentau.scripts.grpc import robot_inference_pb2 as robot__inference__pb2
|
|
20
|
+
|
|
21
|
+
# Version of grpcio-tools that generated this module; the installed grpc runtime
# must not be older than this.
GRPC_GENERATED_VERSION = "1.76.0"
GRPC_VERSION = grpc.__version__
_version_not_supported = False

try:
    from grpc._utilities import first_version_is_lower

    _version_not_supported = first_version_is_lower(GRPC_VERSION, GRPC_GENERATED_VERSION)
except ImportError:
    # If the comparison helper cannot be imported, the runtime is treated as unsupported.
    _version_not_supported = True

# Fail at import time with an actionable message instead of failing later at call time.
if _version_not_supported:
    raise RuntimeError(
        f"The grpc package installed is at version {GRPC_VERSION},"
        + " but the generated code in robot_inference_pb2_grpc.py depends on"
        + f" grpcio>={GRPC_GENERATED_VERSION}."
        + f" Please upgrade your grpc module to grpcio>={GRPC_GENERATED_VERSION}"
        + f" or downgrade your generated code using grpcio-tools<={GRPC_VERSION}."
    )
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class RobotPolicyServiceStub:
    """Service for robot policy inference

    Client-side stub: each attribute created in ``__init__`` is a callable bound
    to one RPC of ``robot_inference.RobotPolicyService`` on the given channel.
    """

    def __init__(self, channel):
        """Constructor.

        Args:
            channel: A grpc.Channel.
        """
        # Unary-unary: one ObservationRequest -> one ActionChunkResponse.
        self.GetActionChunk = channel.unary_unary(
            "/robot_inference.RobotPolicyService/GetActionChunk",
            request_serializer=robot__inference__pb2.ObservationRequest.SerializeToString,
            response_deserializer=robot__inference__pb2.ActionChunkResponse.FromString,
            _registered_method=True,
        )
        # Stream-stream: an iterator of ObservationRequest in, ActionChunkResponse stream out.
        self.StreamActionChunks = channel.stream_stream(
            "/robot_inference.RobotPolicyService/StreamActionChunks",
            request_serializer=robot__inference__pb2.ObservationRequest.SerializeToString,
            response_deserializer=robot__inference__pb2.ActionChunkResponse.FromString,
            _registered_method=True,
        )
        # Unary-unary: HealthCheckRequest -> HealthCheckResponse.
        self.HealthCheck = channel.unary_unary(
            "/robot_inference.RobotPolicyService/HealthCheck",
            request_serializer=robot__inference__pb2.HealthCheckRequest.SerializeToString,
            response_deserializer=robot__inference__pb2.HealthCheckResponse.FromString,
            _registered_method=True,
        )
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
class RobotPolicyServiceServicer:
    """Service for robot policy inference

    Server-side base class. Subclasses override the RPC methods; every default
    implementation below marks the call UNIMPLEMENTED and raises
    NotImplementedError.
    """

    def GetActionChunk(self, request, context):
        """Get action chunk from observations (single request-response)"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Method not implemented!")
        raise NotImplementedError("Method not implemented!")

    def StreamActionChunks(self, request_iterator, context):
        """Streaming version for continuous inference (robot sends observations, server sends action chunks)"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Method not implemented!")
        raise NotImplementedError("Method not implemented!")

    def HealthCheck(self, request, context):
        """Health check"""
        context.set_code(grpc.StatusCode.UNIMPLEMENTED)
        context.set_details("Method not implemented!")
        raise NotImplementedError("Method not implemented!")
|
|
91
|
+
|
|
92
|
+
|
|
93
|
+
def add_RobotPolicyServiceServicer_to_server(servicer, server):
    """Register ``servicer``'s RPC implementations on ``server``.

    Builds one method handler per RPC, wired to the matching protobuf
    (de)serializers. It then registers them under the
    ``robot_inference.RobotPolicyService`` service name, both as a generic
    handler and as registered method handlers.

    Args:
        servicer: Object providing GetActionChunk, StreamActionChunks and HealthCheck.
        server: The grpc.Server to register on.
    """
    rpc_method_handlers = {
        "GetActionChunk": grpc.unary_unary_rpc_method_handler(
            servicer.GetActionChunk,
            request_deserializer=robot__inference__pb2.ObservationRequest.FromString,
            response_serializer=robot__inference__pb2.ActionChunkResponse.SerializeToString,
        ),
        "StreamActionChunks": grpc.stream_stream_rpc_method_handler(
            servicer.StreamActionChunks,
            request_deserializer=robot__inference__pb2.ObservationRequest.FromString,
            response_serializer=robot__inference__pb2.ActionChunkResponse.SerializeToString,
        ),
        "HealthCheck": grpc.unary_unary_rpc_method_handler(
            servicer.HealthCheck,
            request_deserializer=robot__inference__pb2.HealthCheckRequest.FromString,
            response_serializer=robot__inference__pb2.HealthCheckResponse.SerializeToString,
        ),
    }
    generic_handler = grpc.method_handlers_generic_handler(
        "robot_inference.RobotPolicyService", rpc_method_handlers
    )
    server.add_generic_rpc_handlers((generic_handler,))
    server.add_registered_method_handlers("robot_inference.RobotPolicyService", rpc_method_handlers)
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
# This class is part of an EXPERIMENTAL API.
class RobotPolicyService:
    """Service for robot policy inference

    One-shot client helpers built on ``grpc.experimental``. Each static method
    takes a ``target`` address (plus optional credentials/options) instead of a
    channel, and issues a single call to the corresponding RPC.
    """

    @staticmethod
    def GetActionChunk(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        # Unary call: ObservationRequest -> ActionChunkResponse.
        return grpc.experimental.unary_unary(
            request,
            target,
            "/robot_inference.RobotPolicyService/GetActionChunk",
            robot__inference__pb2.ObservationRequest.SerializeToString,
            robot__inference__pb2.ActionChunkResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True,
        )

    @staticmethod
    def StreamActionChunks(
        request_iterator,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        # Bidirectional streaming call: iterator of ObservationRequest in,
        # stream of ActionChunkResponse out.
        return grpc.experimental.stream_stream(
            request_iterator,
            target,
            "/robot_inference.RobotPolicyService/StreamActionChunks",
            robot__inference__pb2.ObservationRequest.SerializeToString,
            robot__inference__pb2.ActionChunkResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True,
        )

    @staticmethod
    def HealthCheck(
        request,
        target,
        options=(),
        channel_credentials=None,
        call_credentials=None,
        insecure=False,
        compression=None,
        wait_for_ready=None,
        timeout=None,
        metadata=None,
    ):
        # Unary call: HealthCheckRequest -> HealthCheckResponse.
        return grpc.experimental.unary_unary(
            request,
            target,
            "/robot_inference.RobotPolicyService/HealthCheck",
            robot__inference__pb2.HealthCheckRequest.SerializeToString,
            robot__inference__pb2.HealthCheckResponse.FromString,
            options,
            channel_credentials,
            insecure,
            call_credentials,
            compression,
            wait_for_ready,
            timeout,
            metadata,
            _registered_method=True,
        )
|
|
@@ -0,0 +1,313 @@
|
|
|
1
|
+
# Copyright 2026 Tensor Auto Inc. All rights reserved.
|
|
2
|
+
#
|
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
4
|
+
# you may not use this file except in compliance with the License.
|
|
5
|
+
# You may obtain a copy of the License at
|
|
6
|
+
#
|
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
8
|
+
#
|
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
12
|
+
# See the License for the specific language governing permissions and
|
|
13
|
+
# limitations under the License.
|
|
14
|
+
|
|
15
|
+
"""gRPC server for robot policy inference on GPU.
|
|
16
|
+
|
|
17
|
+
This server loads an ML policy model and serves inference requests from
|
|
18
|
+
robots running ROS 2. Designed to run on a GPU-equipped inference server.
|
|
19
|
+
|
|
20
|
+
Usage:
|
|
21
|
+
python src/opentau/scripts/grpc/server.py --config_path=/path/to/config.json \\
|
|
22
|
+
--server.port=50051 --server.max_workers=4
|
|
23
|
+
"""
|
|
24
|
+
|
|
25
|
+
import io
|
|
26
|
+
import logging
|
|
27
|
+
from concurrent import futures
|
|
28
|
+
from dataclasses import asdict
|
|
29
|
+
from pprint import pformat
|
|
30
|
+
from typing import Iterator
|
|
31
|
+
|
|
32
|
+
import numpy as np
|
|
33
|
+
import torch
|
|
34
|
+
from PIL import Image
|
|
35
|
+
|
|
36
|
+
import grpc
|
|
37
|
+
from opentau.configs import parser
|
|
38
|
+
from opentau.configs.train import TrainPipelineConfig
|
|
39
|
+
from opentau.policies.factory import get_policy_class
|
|
40
|
+
from opentau.scripts.grpc import robot_inference_pb2, robot_inference_pb2_grpc
|
|
41
|
+
from opentau.utils.random_utils import set_seed
|
|
42
|
+
from opentau.utils.utils import (
|
|
43
|
+
attempt_torch_compile,
|
|
44
|
+
auto_torch_device,
|
|
45
|
+
init_logging,
|
|
46
|
+
)
|
|
47
|
+
|
|
48
|
+
# Module-level logger shared by the servicer and the server entry points.
logger = logging.getLogger(__name__)
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class RobotPolicyServicer(robot_inference_pb2_grpc.RobotPolicyServiceServicer):
    """gRPC servicer implementing the RobotPolicyService.

    A single policy instance is loaded when the servicer is constructed and is
    used by every RPC this servicer handles.
    """

    def __init__(self, cfg: TrainPipelineConfig):
        """Initialize the servicer with model and configuration.

        Args:
            cfg: Training pipeline configuration including policy settings.
        """
        self.cfg = cfg
        self.device = auto_torch_device()
        # Model weights and every input tensor are cast to this dtype.
        self.dtype = torch.bfloat16

        logger.info(f"Initializing policy on device: {self.device}")

        # Load the policy model
        self._load_policy()

    def _load_policy(self):
        """Load the policy from ``cfg.policy.pretrained_path`` and prepare it for inference.

        The policy is moved to ``self.device`` / ``self.dtype``, switched to eval
        mode, passed through ``attempt_torch_compile`` and finally reset.
        """
        logger.info(f"Loading policy from: {self.cfg.policy.pretrained_path}")

        policy_class = get_policy_class(self.cfg.policy.type)
        self.policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=self.cfg.policy)
        self.policy.to(device=self.device, dtype=self.dtype)
        self.policy.eval()
        self.policy = attempt_torch_compile(self.policy, device_hint=self.device)
        self.policy.reset()
        logger.info("Policy loaded successfully")

    def _decode_image(self, camera_image: robot_inference_pb2.CameraImage) -> torch.Tensor:
        """Decode an image from the protobuf message.

        Supported encodings (matched case-insensitively):

        * ``"jpeg"`` / ``"png"``: compressed bytes decoded with PIL, converted to
          RGB and resized to ``cfg.resolution``.
        * ``"raw"``: float32 bytes in HWC order for a square 3-channel image.
          The values are presumably already scaled to [0, 1] (TODO confirm with
          the client). The image is resized only when its size differs from
          ``cfg.resolution``.

        Args:
            camera_image: CameraImage protobuf message.

        Returns:
            Tensor of shape (1, C, H, W) on ``self.device`` with ``self.dtype``.

        Raises:
            ValueError: If the encoding is unknown or the payload cannot be decoded.
                ``GetActionChunk`` maps this to INVALID_ARGUMENT.
        """
        # Normalize to a tuple so comparisons work even if the config holds a list.
        target_hw = tuple(self.cfg.resolution)
        encoding = camera_image.encoding.lower()

        if encoding in ("jpeg", "png"):
            try:
                image = Image.open(io.BytesIO(camera_image.image_data))
                image = image.convert("RGB")
            except OSError as err:
                # PIL signals unreadable/truncated data with OSError (including
                # UnidentifiedImageError). Report it as a client error, not an
                # internal failure.
                raise ValueError(f"Could not decode {encoding} image: {err}") from err
            image = image.resize(target_hw[::-1])  # PIL uses (W, H)
            image_array = np.array(image, dtype=np.float32) / 255.0
        elif encoding == "raw":
            # frombuffer returns a read-only view; copy so torch.from_numpy gets a writable array.
            image_array = np.frombuffer(camera_image.image_data, dtype=np.float32).copy()
            side = int(round(np.sqrt(image_array.size / 3)))
            if side == 0 or side * side * 3 != image_array.size:
                raise ValueError(
                    "Raw image must be a non-empty square 3-channel float32 array, "
                    f"got {image_array.size} values"
                )
            image_array = image_array.reshape(side, side, 3)
            # Only pay for the lossy uint8 round-trip when a resize is actually needed.
            if (side, side) != target_hw:
                image = Image.fromarray((image_array * 255).astype(np.uint8))
                image = image.resize(target_hw[::-1])
                image_array = np.array(image, dtype=np.float32) / 255.0
        else:
            raise ValueError(f"Unknown image encoding: {camera_image.encoding}")

        # HWC -> CHW, then add a leading batch dimension of 1.
        image_tensor = torch.from_numpy(image_array).permute(2, 0, 1).unsqueeze(0)
        return image_tensor.to(device=self.device, dtype=self.dtype)

    def _prepare_observation(
        self, request: robot_inference_pb2.ObservationRequest
    ) -> dict[str, torch.Tensor]:
        """Convert a protobuf observation request to the policy input format.

        The returned batch contains:

        * ``camera{i}``: one decoded image per request image. If fewer than
          ``cfg.num_cams`` images are sent, the remaining slots are zero-filled.
        * ``img_is_pad``: bool tensor of shape (1, num_slots), True for zero-filled slots.
        * ``state``: tensor of shape (1, cfg.max_state_dim). The robot state is
          zero-padded, or truncated if longer.
        * ``prompt``: single-element list holding the request prompt (or ``""``).

        Args:
            request: ObservationRequest protobuf message.

        Returns:
            Dictionary of tensors matching the policy's expected input format.

        Raises:
            ValueError: If the robot state is missing or an image cannot be decoded.
        """
        batch = {}

        # Process camera images
        img_is_pad = []
        for i, camera_image in enumerate(request.images):
            batch[f"camera{i}"] = self._decode_image(camera_image)
            img_is_pad.append(False)

        # Fill in missing cameras with zeros
        for i in range(len(request.images), self.cfg.num_cams):
            batch[f"camera{i}"] = torch.zeros(
                (1, 3, *self.cfg.resolution),
                dtype=self.dtype,
                device=self.device,
            )
            img_is_pad.append(True)

        batch["img_is_pad"] = torch.tensor([img_is_pad], dtype=torch.bool, device=self.device)

        # Process robot state
        if not request.robot_state.state:
            raise ValueError("Robot state is required but was not provided in the request")
        state = list(request.robot_state.state)
        # Pad to max_state_dim if needed; longer states are truncated by the slice below.
        if len(state) < self.cfg.max_state_dim:
            state.extend([0.0] * (self.cfg.max_state_dim - len(state)))
        batch["state"] = torch.tensor(
            [state[: self.cfg.max_state_dim]],
            dtype=self.dtype,
            device=self.device,
        )

        # Process prompt
        batch["prompt"] = [request.prompt] if request.prompt else [""]

        return batch

    def GetActionChunk(
        self,
        request: robot_inference_pb2.ObservationRequest,
        context: grpc.ServicerContext,
    ) -> robot_inference_pb2.ActionChunkResponse:
        """Handle a single action chunk inference request.

        Malformed input (``ValueError`` while preparing the observation) aborts
        the RPC with INVALID_ARGUMENT. Any other failure, including a
        ``ValueError`` raised inside the model, aborts with INTERNAL and logs a
        traceback.

        Args:
            request: ObservationRequest containing observations.
            context: gRPC context.

        Returns:
            ActionChunkResponse containing the predicted action chunk.
        """
        import time

        start_time = time.perf_counter()
        response = robot_inference_pb2.ActionChunkResponse()
        response.request_id = request.request_id
        response.timestamp_ns = time.time_ns()

        # context.abort() raises to terminate the RPC. It is therefore only
        # called from except clauses, never inside a try that catches Exception.
        try:
            batch = self._prepare_observation(request)
        except ValueError as e:
            # Invalid request (e.g., missing required fields, undecodable image)
            logger.error(f"Invalid request: {e}")
            context.abort(grpc.StatusCode.INVALID_ARGUMENT, str(e))
        except Exception as e:
            logger.exception("Error during inference")
            context.abort(grpc.StatusCode.INTERNAL, f"Inference error: {e}")

        try:
            # Run inference
            with torch.inference_mode():
                action_chunk = self.policy.sample_actions(batch)
                # action_chunk shape: (n_action_steps, batch_size=1, action_dim)
                # Remove batch dimension and convert to numpy
                action_chunk = action_chunk.squeeze(1).to("cpu", torch.float32).numpy()

            # Populate 2D action chunk structure: one ActionVector per step.
            for action_vector in action_chunk:
                action_vec_msg = robot_inference_pb2.ActionVector()
                action_vec_msg.values.extend(action_vector.tolist())
                response.action_chunk.append(action_vec_msg)
        except Exception as e:
            # Unexpected error during inference
            logger.exception("Error during inference")
            context.abort(grpc.StatusCode.INTERNAL, f"Inference error: {e}")

        response.inference_time_ms = (time.perf_counter() - start_time) * 1000
        return response

    def StreamActionChunks(
        self,
        request_iterator: Iterator[robot_inference_pb2.ObservationRequest],
        context: grpc.ServicerContext,
    ) -> Iterator[robot_inference_pb2.ActionChunkResponse]:
        """Handle streaming action chunk inference requests.

        Each incoming observation is answered via ``GetActionChunk``. An error
        there aborts the whole stream. Iteration stops once the context is no
        longer active.

        Args:
            request_iterator: Iterator of ObservationRequest messages.
            context: gRPC context.

        Yields:
            ActionChunkResponse messages for each observation.
        """
        for request in request_iterator:
            if not context.is_active():
                break
            yield self.GetActionChunk(request, context)

    def HealthCheck(
        self,
        request: robot_inference_pb2.HealthCheckRequest,
        context: grpc.ServicerContext,
    ) -> robot_inference_pb2.HealthCheckResponse:
        """Check server health and GPU status.

        GPU memory fields are filled only when the policy runs on a CUDA device,
        and they describe that device (not necessarily CUDA device 0).

        Args:
            request: HealthCheckRequest message.
            context: gRPC context.

        Returns:
            HealthCheckResponse with server status.
        """
        response = robot_inference_pb2.HealthCheckResponse()
        response.healthy = True
        response.status = "Server is running"
        response.model_name = self.cfg.policy.type
        response.device = str(self.device)

        # Normalize in case auto_torch_device() returned a string.
        device = torch.device(self.device)
        if device.type == "cuda" and torch.cuda.is_available():
            response.gpu_memory_used_gb = torch.cuda.memory_allocated(device) / 1e9
            response.gpu_memory_total_gb = torch.cuda.get_device_properties(device).total_memory / 1e9

        return response
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
def serve(cfg: TrainPipelineConfig):
    """Start the gRPC server and block until it terminates.

    ``cfg.server`` supplies ``port``, ``max_workers``, ``max_send_message_length``
    and ``max_receive_message_length``. The policy is loaded via
    ``RobotPolicyServicer`` before the server starts accepting requests. On
    Ctrl-C the server is stopped with a 5 s grace period, and this function
    waits for the shutdown to complete.

    Args:
        cfg: Training pipeline configuration including server settings.
    """
    server_cfg = cfg.server
    server = grpc.server(
        futures.ThreadPoolExecutor(max_workers=server_cfg.max_workers),
        options=[
            ("grpc.max_send_message_length", server_cfg.max_send_message_length),
            ("grpc.max_receive_message_length", server_cfg.max_receive_message_length),
        ],
    )

    servicer = RobotPolicyServicer(cfg)
    robot_inference_pb2_grpc.add_RobotPolicyServiceServicer_to_server(servicer, server)

    # add_insecure_port returns the port actually bound, which differs from the
    # configured one when port 0 (ephemeral) is requested.
    bound_port = server.add_insecure_port(f"[::]:{server_cfg.port}")
    server.start()

    logger.info(f"Server started on port {bound_port}")
    logger.info(f"Policy: {cfg.policy.type}")
    logger.info(f"Device: {servicer.device}")
    logger.info(f"Max workers: {server_cfg.max_workers}")
    logger.info("Waiting for requests...")

    try:
        server.wait_for_termination()
    except KeyboardInterrupt:
        logger.info("Shutting down server...")
        # stop() returns a threading.Event set once shutdown completes. Waiting on it
        # lets in-flight RPCs use the grace period instead of being cut off at exit.
        server.stop(grace=5).wait()
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
@parser.wrap()
def server_main(cfg: TrainPipelineConfig):
    """Main entry point for the gRPC server.

    Logs the resolved configuration, seeds the RNGs when ``cfg.seed`` is set,
    then blocks in ``serve`` until the server terminates.

    Args:
        cfg: Training pipeline configuration parsed from CLI/config file.
    """
    # Use the module logger (like the rest of this file) rather than the root logger.
    logger.info(pformat(asdict(cfg)))

    if cfg.seed is not None:
        set_seed(cfg.seed)

    serve(cfg)
|
|
309
|
+
|
|
310
|
+
|
|
311
|
+
# Script entry point: configure logging first, then start the server.
# server_main is called without arguments; presumably @parser.wrap() builds cfg
# from the command line / config file (TODO confirm in opentau.configs.parser).
if __name__ == "__main__":
    init_logging()
    server_main()
|
opentau/scripts/launch.py
CHANGED
|
@@ -18,10 +18,6 @@ import sys
|
|
|
18
18
|
from pathlib import Path
|
|
19
19
|
from types import ModuleType
|
|
20
20
|
|
|
21
|
-
import opentau.scripts.eval as eval_script
|
|
22
|
-
import opentau.scripts.export_to_onnx as export_script
|
|
23
|
-
import opentau.scripts.train as train_script
|
|
24
|
-
|
|
25
21
|
|
|
26
22
|
def launch(script_module: ModuleType, description: str, use_accelerate: bool = True):
|
|
27
23
|
"""Generic launcher for OpenTau scripts using Accelerate or Python."""
|
|
@@ -68,12 +64,24 @@ def launch(script_module: ModuleType, description: str, use_accelerate: bool = T
|
|
|
68
64
|
|
|
69
65
|
|
|
70
66
|
def train():
    """Run the OpenTau training script through the Accelerate launcher."""
    # Imported on demand so other entry points don't pay for loading this module.
    from opentau.scripts import train as train_script

    launch(script_module=train_script, description="Launch OpenTau training with Accelerate")
|
|
72
70
|
|
|
73
71
|
|
|
74
72
|
def eval():
    """Run the OpenTau evaluation script through the Accelerate launcher."""
    # Imported on demand so other entry points don't pay for loading this module.
    from opentau.scripts import eval as eval_script

    launch(script_module=eval_script, description="Launch OpenTau evaluation with Accelerate")
|
|
76
76
|
|
|
77
77
|
|
|
78
78
|
def export():
    """Run the OpenTau ONNX export script directly (without Accelerate)."""
    # Imported on demand so other entry points don't pay for loading this module.
    from opentau.scripts import export_to_onnx as export_script

    launch(
        script_module=export_script,
        description="Launch OpenTau ONNX export",
        use_accelerate=False,
    )
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def visualize():
    """Run the OpenTau dataset visualization script directly (without Accelerate)."""
    # Imported on demand so other entry points don't pay for loading this module.
    from opentau.scripts import visualize_dataset as visualize_script

    launch(
        script_module=visualize_script,
        description="Launch OpenTau visualization",
        use_accelerate=False,
    )
|